Load Packages

library(MMCTime)
#> Warning: replacing previous import 'ape::where' by 'dplyr::where' when loading
#> 'MMCTime'
#> Warning: replacing previous import 'ape::rotate' by 'ggtree::rotate' when
#> loading 'MMCTime'
library(ggplot2)
#> Use suppressPackageStartupMessages() to eliminate package startup
#> messages
library(ape)
library(posterior)
#> This is posterior version 1.5.0
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var
#> The following objects are masked from 'package:base':
#> 
#>     %in%, match
library(reshape2)
library(patchwork)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following object is masked from 'package:ape':
#> 
#>     where
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(phangorn)

Ingest Data

# Simulation design: every clock setting is crossed with every phi scenario,
# and each (clock, phi) cell holds `set_size` replicate runs.
phi_sets <- 4L    # number of phi scenarios (indexes `phis`)
mut_sets <- 3L    # number of clock settings (indexes `mus` and `omegas`)
set_size <- 48L   # replicate runs per cell (indexes `alphas`)

# Used below to select the q_1 .. q_(2 * n_tips - 1) columns of the draws and
# the first n_tips node depths; presumably the number of tips per tree.
n_tips <- 200

# Ground-truth parameter values drawn as red reference lines in the plots.
nu <- 1 / 10
alphas <- seq(from = 0.01, to = 0.75, length.out = set_size)
phis <- c(0.0, 0.5, 0.75, 1.0)
mus <- c(1.5, 3.0, 6.0)
omegas <- c(0.5, 1.0, 2.0)

# Summary variable that is compared against the true tree height
# (presumably the root node, n_tips + 1 -- TODO confirm).
root_n <- "t_201"
# Names given to the relative-error rows derived from the summaries.
height_n <- "rel_height"
len_n <- "rel_length"

# Load the posterior summaries of one analysis run and express tree height and
# tree length as relative errors against the ground-truth tree.
#
# clock, phi_idx, run_idx: integer indices of the clock setting, phi scenario
#   and replicate; together they select the result file and the true tree.
# Returns a data frame with columns variable, median, q5, q95, rhat, ess_bulk,
#   run_idx, phi_idx and clock. The `root_n` row is renamed to `height_n` and
#   the "tree_length" row to `len_n`, with q5/median/q95 replaced by
#   (estimate - truth) / truth.
f <- function(clock, phi_idx, run_idx)
{
    # Files are numbered consecutively across phi scenarios.
    idx <- (phi_idx - 1L) * set_size + run_idx
    res <- readRDS(paste0("./mut", clock, "/ana_out/res_", idx, ".rds"))
    tmp <- as.data.frame(res$summaries[, c("variable", "median", "q5", "q95", "rhat", "ess_bulk")])
    tmp$run_idx <- run_idx
    tmp$phi_idx <- phi_idx
    tmp$clock <- clock

    # Parse the ground-truth tree once and derive both reference values from it.
    gt_tree <- read.tree(paste0("./gt/tree_", idx, ".nwk"))
    gt_h <- max(node.depth.edgelength(gt_tree))
    gt_len <- sum(gt_tree$edge.length)

    stat_cols <- c("q5", "median", "q95")

    is_root <- tmp$variable == root_n
    tmp[is_root, stat_cols] <- (tmp[is_root, stat_cols] - gt_h) / gt_h
    tmp[is_root, "variable"] <- height_n

    is_len <- tmp$variable == "tree_length"
    tmp[is_len, stat_cols] <- (tmp[is_len, stat_cols] - gt_len) / gt_len
    tmp[is_len, "variable"] <- len_n

    tmp
}

# Stack the per-run summaries of every (clock, phi, run) combination.
# expand.grid varies its first factor fastest, so runs are visited in the
# order clock (outer) -> phi scenario -> replicate (inner).
dfs <- with(
    expand.grid(
        run_idx = seq_len(set_size),
        phi_idx = seq_len(phi_sets),
        clock = seq_len(mut_sets),
        KEEP.OUT.ATTRS = FALSE
    ),
    do.call(rbind, Map(f, clock, phi_idx, run_idx))
)

Plot Posterior Summaries Against Ground Truth

# Variables that get a summary figure; also fixes the factor level order.
vnames <- c(
    "phi", "alpha", "nu", "omega", "mu",
    height_n, len_n
)

# Keep only the rows of `df` whose `variable` column equals `v`.
subset_v <- function (df, v)
{
    keep <- df$variable == v
    df[keep, ]
}

# Replace the phi scenario index (1..4) in `phi_idx` by the phi value it
# stands for, as a character label used in the facet strips.
relab <- function (df)
{
    mutate(
        df,
        phi_idx = recode(phi_idx, "1" = "0.0", "2" = "0.5", "3" = "0.75", "4" = "1.0")
    )
}

# Table of runs that fail the convergence check; these are marked with a red
# cross in the figures. A run passes when every summarised variable has
# ess_bulk > 200 and rhat < 1.05.
conv_df <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
)[, c("clock", "phi_idx", "run_idx")]

conv_df$conv <- mapply(
    function(cl, ph, ru)
    {
        hits <- which((dfs$clock == cl) & (dfs$phi_idx == ph) & (dfs$run_idx == ru))
        all(dfs$ess_bulk[hits] > 200) && all(dfs$rhat[hits] < 1.05)
    },
    conv_df$clock, conv_df$phi_idx, conv_df$run_idx
)

# Keep the failing runs only and label the phi scenarios like the plot data.
conv_df <- relab(conv_df[!conv_df$conv, ])


# Plot the posterior median and q5-q95 interval of one variable for every run,
# faceted by phi scenario (rows) and clock setting (columns).
#
# df:    stacked per-run summaries as produced by f() (columns variable,
#        median, q5, q95, run_idx, phi_idx, clock).
# vname: name of the variable to plot; one of `vnames`.
# Returns a ggplot object. The red line is the ground-truth value (0 for the
#   relative-error variables); red crosses at the bottom of a panel mark the
#   runs listed in the global `conv_df`.
sum_plt <- function(df, vname)
{
    allDf <- df[df$variable %in% vnames, ]

    # Ground-truth value of every variable for every (clock, phi, run) cell,
    # in the same order as the simulation design.
    gt_df <- expand.grid(
        run_idx = seq_len(set_size),
        phi_idx = seq_len(phi_sets),
        clock = seq_len(mut_sets),
        KEEP.OUT.ATTRS = FALSE
    )[, c("clock", "phi_idx", "run_idx")]
    gt_df$phi <- phis[gt_df$phi_idx]
    gt_df$alpha <- alphas[gt_df$run_idx]
    gt_df$mu <- mus[gt_df$clock]
    gt_df$omega <- omegas[gt_df$clock]
    gt_df$nu <- nu
    # Relative errors are zero at the truth by construction.
    gt_df[[height_n]] <- 0.0
    gt_df[[len_n]] <- 0.0

    gt_df <- melt(gt_df, measure.vars = vnames)
    gt_df$variable <- factor(gt_df$variable, levels = vnames)
    allDf$variable <- factor(allDf$variable, levels = vnames)

    gt_df <- relab(subset_v(gt_df, vname))
    allDf <- relab(subset_v(allDf, vname))

    ggplot(allDf, aes(x = run_idx, ymin = q5, y = median, ymax = q95)) +
    geom_errorbar(width = .1, alpha = 0.4) +
    geom_point(size = .5, alpha = 0.8) +
    # y = -Inf pins the markers to the bottom edge of each panel.
    geom_point(data = conv_df, aes(x = run_idx, y = -Inf), color = "red", size = 2.0, shape = 4, inherit.aes = FALSE) +
    geom_line(data = gt_df, aes(x = run_idx, y = value), color = "red", inherit.aes = FALSE) +
    facet_grid(rows = vars(phi_idx), cols = vars(clock), labeller = label_bquote("Phi Scenario"==.(phi_idx), "Clock"==.(clock))) +
    theme_minimal() +
    labs(x = "Run") +
    # Clipping is off so the markers drawn on the panel edge stay visible.
    coord_cartesian(clip = 'off') +
    ggtitle(vname) +
    theme(
        axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
        plot.margin = margin(0, 0, 0, 0, "cm"),
        panel.grid.major = element_blank(),
        axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
        plot.title = element_text(hjust = 0.5, size = rel(1.0)))
}
# One summary figure per variable in `vnames`: build it with sum_plt(), write
# it to an 8 x 8 PDF under ../manuscript_figs/, then print it inline.

# phi
p <- sum_plt(dfs, "phi")
#> Warning: The `size` argument of `element_line()` is deprecated as of ggplot2 3.4.0.
#> ℹ Please use the `linewidth` argument instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.
pdf("../manuscript_figs/kmb_phi.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# alpha
p <- sum_plt(dfs, "alpha")
pdf("../manuscript_figs/kmb_alpha.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# nu
p <- sum_plt(dfs, "nu")
pdf("../manuscript_figs/kmb_nu.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# mu
p <- sum_plt(dfs, "mu")
pdf("../manuscript_figs/kmb_mu.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# omega
p <- sum_plt(dfs, "omega")
pdf("../manuscript_figs/kmb_omega.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# Relative error of the tree height; the file name is derived from height_n.
p <- sum_plt(dfs, height_n)
pdf(paste0("../manuscript_figs/kmb_", height_n, ".pdf"),8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# Relative error of the tree length; the file name is derived from len_n.
p <- sum_plt(dfs, len_n)
pdf(paste0("../manuscript_figs/kmb_", len_n, ".pdf"),8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# Summarise the posterior branch count of one run relative to the true tree.
#
# clock, phi_idx, run_idx: integer indices selecting the run.
# Returns a named numeric vector: quantiles of (branch count - true branch
#   count), the posterior fraction of draws with fewer than 2 * n_tips - 2
#   branches (`p_mm`), and the three indices.
# NOTE(review): the entries named "q5" and "q95" hold the 2.5% and 97.5%
#   quantiles, not the 5% and 95% ones -- confirm this is intended.
h <- function(clock, phi_idx, run_idx)
{
    idx <- (phi_idx-1L) * set_size + run_idx
    res <- readRDS(paste0("./mut",clock,"/ana_out/res_", idx, ".rds"))

    # Columns q_1 .. q_(2 * n_tips - 1) of the draws.
    n_q <- 2*n_tips-1
    q_idx <- which(colnames(res$draws) %in% paste0("q_",1:n_q))
    # Per-draw branch count; one is subtracted because the root has a q but
    # no branch.
    bcounts <- apply(res$draws, 1, function(x) sum(1-x[q_idx])) - 1

    # Branch count of the ground-truth tree after di2multi().
    gt_br <- nrow(di2multi(read.tree(paste0("./gt/tree_",idx,".nwk")))$edge)

    b_ci <- quantile(bcounts-gt_br, probs = c(0.025, .5, 0.975))
    p_mm <- sum(((2*n_tips-2)-bcounts) > 0)/length(bcounts)

    setNames(
        c(b_ci, p_mm, run_idx, phi_idx, clock),
        c("q5", "median", "q95", "p_mm", "run_idx", "phi_idx", "clock")
    )
}
# Branch-count summaries of every run, one row per (clock, phi, run) in the
# order clock (outer) -> phi scenario -> replicate (inner).
df_bci <- with(
    expand.grid(
        run_idx = seq_len(set_size),
        phi_idx = seq_len(phi_sets),
        clock = seq_len(mut_sets),
        KEEP.OUT.ATTRS = FALSE
    ),
    data.frame(do.call(rbind, Map(h, clock, phi_idx, run_idx)))
)

# Show phi values instead of scenario indices in the facet labels.
df_bci <- relab(df_bci)

# Posterior branch count minus the true branch count for every run (median
# and interval), with zero as the reference line. Red crosses at the bottom
# of a panel mark the runs listed in `conv_df`.
# `df_bci` has already been relabelled, so it is used as is, and the
# convergence markers are added once.
p2 <- ggplot(df_bci, aes(x = run_idx, ymin = q5, y = median, ymax = q95)) +
    geom_errorbar(width = .1, alpha = 0.4) +
    geom_point(size = .5, alpha = 0.8) +
    geom_hline(yintercept = 0, color = "red") +
    # y = -Inf pins the markers to the bottom edge of each panel.
    geom_point(data = conv_df, aes(x = run_idx, y = -Inf), color = "red", size = 2.0, shape = 4, inherit.aes = FALSE) +
    facet_grid(rows = vars(phi_idx), cols = vars(clock), labeller = label_bquote("Phi Scenario"==.(phi_idx), "Clock"==.(clock))) +
    theme_minimal() +
    labs(x = "Run") +
    coord_cartesian(clip = 'off') +
    theme(
        axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
        axis.text.y = element_text(size = rel(0.9)),
        plot.margin = margin(0, 0, 0, 0, "cm"),
        panel.grid.major = element_blank(),
        axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
        plot.title = element_text(hjust = 0.5, size = rel(1.0)))


# Posterior fraction of draws with fewer than 2 * n_tips - 2 branches (`p_mm`)
# for every run, on a log2 axis with reference lines at the axis breaks.
p3 <- ggplot(df_bci, aes(x = run_idx, y = p_mm)) +
    geom_point(size = .5, alpha = 0.8) +
    geom_hline(yintercept = c(.99, .95, .90, .75, .5), color = "red") +
    # y = 0 maps to -Inf on the log2 axis, which puts the markers on the bottom
    # edge of the panel (and triggers the "infinite values" warning).
    geom_point(data = conv_df, aes(x = run_idx, y = 0), color = "red", size = 2.0, shape = 4, inherit.aes = FALSE) +
    facet_grid(rows = vars(phi_idx), cols = vars(clock), labeller = label_bquote("Phi Scenario"==.(phi_idx), "Clock"==.(clock))) +
    scale_y_continuous(transform = 'log2', breaks = c(.99, .95, .90, .75, .5)) +
    coord_cartesian(clip = 'off') +
    theme_minimal() +
    labs(x = "Run") +
    theme(
        axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
        axis.text.y = element_text(size = rel(0.9)),
        plot.margin = margin(0, 0, 0, 0, "cm"),
        panel.grid.major = element_blank(),
        axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
        plot.title = element_text(hjust = 0.5, size = rel(1.0)))

Plot relative branch counts

# Write the branch-count figure to an 8 x 8 PDF, then print it inline.
pdf("../manuscript_figs/kmb_bcount.pdf",8,8)
plot(p2)
dev.off()
#> quartz_off_screen 
#>                 2
p2

Plot Posterior Probabilities of the Tree Containing Multiple Mergers

# Write the p_mm figure to an 8 x 8 PDF, then print it inline. Each rendering
# repeats the log2 warning because the plot contains values that map to
# infinity on that axis.
pdf("../manuscript_figs/kmb_pmm.pdf",8,8)
plot(p3)
#> Warning in scale_y_continuous(trans = "log2", breaks = c(0.99, 0.95, 0.9, :
#> log-2 transformation introduced infinite values.
dev.off()
#> quartz_off_screen 
#>                 2
p3
#> Warning in scale_y_continuous(trans = "log2", breaks = c(0.99, 0.95, 0.9, :
#> log-2 transformation introduced infinite values.

# Mean root-to-tip depth of the clock-scaled tree of one run.
#
# clock, phi_idx, run_idx: integer indices selecting the tree file.
# Returns a named numeric vector (exp_subs, run_idx, clock); phi_idx is only
#   used to locate the file.
count_subs <- function(clock, phi_idx, run_idx)
{
    idx <- (phi_idx-1L) * set_size + run_idx
    tr <- read.tree(paste0("./mut",clock,"/tree_clock_", idx, ".nwk"))
    # The first n_tips entries of node.depth.edgelength() belong to the tips.
    tip_depths <- node.depth.edgelength(tr)[seq_len(n_tips)]
    c(exp_subs = mean(tip_depths), run_idx = run_idx, clock = clock)
}

# Mean root-to-tip depth for every run, one row per (clock, phi, run) in the
# order clock (outer) -> phi scenario -> replicate (inner).
subs_df <- with(
    expand.grid(
        run_idx = seq_len(set_size),
        phi_idx = seq_len(phi_sets),
        clock = seq_len(mut_sets),
        KEEP.OUT.ATTRS = FALSE
    ),
    data.frame(do.call(rbind, Map(count_subs, clock, phi_idx, run_idx)))
)

# Distribution of the mean root-to-tip depth per clock setting.
# NOTE(review): the division by 1e4 presumably converts substitutions per
# sequence into substitutions per site (alignment length 10,000) -- confirm.
p4 <- ggplot(subs_df, aes(factor(clock), exp_subs/1e4L)) +
    geom_violin() +
    theme_minimal() +
    labs(y = "Expected Substitutions Per Site", x = "Clock") +
    theme(
        axis.text.x = element_text(size = rel(1.0), angle = 45, hjust = 1),
        plot.margin = margin(0, 0, 0, 0, "cm"),
        panel.grid.major = element_blank(),
        axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
        plot.title = element_text(hjust = 0.5, size = rel(1.0)))

Plot Expected Number of Substitutions for Each Clock

# Write the expected-substitutions figure to an 8 x 8 PDF, then print it inline.
pdf("../manuscript_figs/kmb_exp_subs.pdf",8,8)
plot(p4)
dev.off()
#> quartz_off_screen 
#>                 2
p4

Plot comparison against TreeTime and LSD2

# Compare our estimate, TreeTime and LSD2 against the ground-truth tree for
# one run.
#
# clock, phi_idx, run_idx: integer indices selecting the run.
# Returns a 12-row data frame with columns value, variable, method, run_idx,
#   phi_idx and clock. `variable` is one of rel_branch_count, rel_length,
#   rel_height (each (estimate - truth) / truth) or dist (KF.dist to the true
#   tree); `method` is "ours", "treetime" or "lsd2".
# When the TreeTime output file is missing, a message is printed and all four
#   TreeTime values are NA instead of the function failing.
h2 <- function(clock, phi_idx, run_idx)
{
    idx <- (phi_idx-1L) * set_size + run_idx
    res_mmc <- readRDS(paste0("./mut",clock,"/ana_out/res_", idx, ".rds"))
    res_lsd <- readRDS(paste0("./mut",clock,"/ana_lsd/res_", idx, ".rds"))

    mmc_sums <- res_mmc$summaries
    lsd_tree <- res_lsd$dateNexusTree@phylo

    # Ground truth, passed through di2multi() before counting branches.
    gt_t <- di2multi(read.tree(paste0("./gt/tree_",idx,".nwk")))
    gt_height <- max(node.depth.edgelength(gt_t))
    gt_length <- sum(gt_t$edge.length)
    gt_bcount <- nrow(gt_t$edge)

    # Ours: median posterior branch count; the root has a q but no branch.
    q_idx <- which(colnames(res_mmc$draws) %in% paste0("q_",1:(2*n_tips-1)))
    bcounts_mmc <- (median(apply(res_mmc$draws, 1, function(x) sum(1-x[q_idx])) - 1)-gt_bcount)/gt_bcount
    tl_mmc <- unlist(mmc_sums[mmc_sums$variable == "tree_length", "median"] - gt_length)/gt_length
    height_mmc <- unlist(mmc_sums[mmc_sums$variable == root_n, "median"]-gt_height)/gt_height
    # Average distance over all posterior trees, sampled without replacement.
    dist_mmc <- mean(vapply(
        sample_timetree(res_mmc, n_samp=res_mmc$n_draws, replace=FALSE),
        function(x) KF.dist(x, gt_t, rooted=TRUE),
        numeric(1)
    ))

    # TreeTime: every statistic stays NA when the run produced no tree.
    tt_res <- paste0("./mut",clock,"/ana_treetime/tree_", idx, "/timetree.nexus")
    bcounts_tt <- NA_real_
    tl_tt <- NA_real_
    height_tt <- NA_real_
    dist_tt <- NA_real_
    if(file.exists(tt_res))
    {
        tt_tree <- read.nexus(tt_res)
        tl_tt <- (sum(tt_tree$edge.length) - gt_length)/gt_length
        height_tt <- (max(node.depth.edgelength(tt_tree)) - gt_height)/gt_height
        bcounts_tt <- (nrow(di2multi(tt_tree)$edge)-gt_bcount)/gt_bcount
        dist_tt <- KF.dist(di2multi(tt_tree), gt_t, rooted=TRUE)
    } else {
       print(paste("Treetime run:",run_idx, "clock:",clock, "failed to complete." ))
    }

    # LSD2
    bcounts_lsd <- (nrow(di2multi(lsd_tree)$edge)-gt_bcount)/gt_bcount
    tl_lsd <- (sum(lsd_tree$edge.length) - gt_length)/gt_length
    height_lsd <- (max(node.depth.edgelength(lsd_tree)) - gt_height)/gt_height
    dist_lsd <- KF.dist(di2multi(lsd_tree), gt_t, rooted=TRUE)

    bcounts <- c(bcounts_mmc, bcounts_tt, bcounts_lsd)
    tl <- c(tl_mmc, tl_tt, tl_lsd)
    height <- c(height_mmc, height_tt, height_lsd)
    dist <- c(dist_mmc, dist_tt, dist_lsd)

    out <- data.frame(value=c(bcounts, tl, height, dist),
        variable=c(rep("rel_branch_count",3),rep("rel_length",3),rep("rel_height",3), rep("dist",3)),
        method=rep(c("ours", "treetime", "lsd2"), 4)
    )

    out$run_idx <- run_idx
    out$phi_idx <- phi_idx
    out$clock <- clock

    out
}

# Method comparison for every run, stacked in the order clock (outer) ->
# phi scenario -> replicate (inner).
df_summs <- with(
    expand.grid(
        run_idx = seq_len(set_size),
        phi_idx = seq_len(phi_sets),
        clock = seq_len(mut_sets),
        KEEP.OUT.ATTRS = FALSE
    ),
    data.frame(do.call(rbind, Map(h2, clock, phi_idx, run_idx)))
)
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both

# Modulus transform: a sign-preserving power transform of |y| + 1 that reduces
# to a signed log1p when lambda is 0.
.mod_transform <- function(y, lambda){
   shifted <- abs(y) + 1
   scaled <- if(lambda == 0){
      log(shifted)
   } else {
      (shifted ^ lambda - 1) / lambda
   }
   return(sign(y) * scaled)
}
# Inverse of .mod_transform() for the same lambda.
.mod_inverse <- function(yt, lambda){
   magnitude <- abs(yt)
   unscaled <- if(lambda == 0){
      exp(magnitude) - 1
   } else {
      (magnitude * lambda + 1) ^ (1 / lambda) - 1
   }
   return(unscaled * sign(yt))
}
# Round each break to two significant digits (coarser for larger magnitudes);
# zeros are passed through unchanged.
prettify <- function(breaks){
   is_zero <- breaks == 0
   digits <- 1 - floor(log10(abs(breaks)))
   # log10(0) is -Inf, so zero needs an explicit digit count.
   digits[is_zero] <- 0
   round(breaks, digits = digits)
}
# Build a break function for a modulus-transformed axis.
#
# lambda:   modulus transform parameter.
# n:        desired number of breaks, passed to pretty().
# prettify: if TRUE, round the breaks with prettify().
# Returns a function of the axis limits `x` that picks pretty breaks on the
#   transformed scale and maps them back to the data scale.
mod_breaks <- function(lambda, n = 6, prettify = TRUE){
   function(x){
      on_mod_scale <- pretty(.mod_transform(x, lambda), n = n)
      breaks <- .mod_inverse(on_mod_scale, lambda)
      # The call below resolves to the prettify() function even though the
      # logical argument has the same name: R skips non-functions when it
      # looks up a name in call position.
      if(prettify){
         breaks <- prettify(breaks)
      }
      breaks
   }
}

# Summary for stat_summary(): the median as `y` and the 2.5% / 97.5%
# quantiles as `ymin` / `ymax`, returned as a one-row data frame.
sum_func <- function(x)
{
    qs <- quantile(x, c(.025, .5, .975), names = FALSE)
    data.frame(y = qs[2], ymin = qs[1], ymax = qs[3])
}

# Violin plot of one comparison statistic per method, faceted by clock setting
# (columns) and phi scenario (rows), on a modulus-transformed y axis.
#
# df: method comparison table as produced by h2() (columns value, variable,
#     method, phi_idx, clock).
# vn: value of `variable` to plot.
# Returns a ggplot object; the point range shows the median and the
#   2.5%-97.5% range, the red line marks zero error.
foo <- function(df, vn)
{
    df %>%
    filter(variable == vn) %>%
    ggplot(aes(x = method, y = value, fill = method)) +
        geom_violin(aes(color = method)) +
        # geom "pointrange" has no `width` parameter, so none is passed.
        stat_summary(fun.data = sum_func, geom = "pointrange", color = "gray35", fill = "gray35") +
        geom_hline(yintercept = 0, color = "red") +
        facet_grid(cols = vars(clock), rows = vars(phi_idx), scales = "free_y", labeller = label_bquote("Phi Scenario"==.(phi_idx), "Clock"==.(clock))) +
        theme_minimal() +
        labs(x = "Method", y = "Relative Error") +
        ggtitle(vn) +
        scale_y_continuous(transform = scales::transform_modulus(0), breaks = mod_breaks(lambda = 0, prettify = TRUE)) +
        theme(
            axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
            plot.margin = margin(0, 0, 0, 0, "cm"),
            panel.grid.major = element_blank(),
            axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
            plot.title = element_text(hjust = 0.5, size = rel(1.0)))
}

# One comparison figure per relative-error statistic: build it, write it to
# an 8 x 8 PDF, then print it inline.
p <- foo(df_summs, "rel_branch_count")
pdf("../manuscript_figs/kmb_compare_bcount.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

p <- foo(df_summs, "rel_length")
pdf("../manuscript_figs/kmb_compare_length.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

p <- foo(df_summs, "rel_height")
pdf("../manuscript_figs/kmb_compare_height.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen 
#>                 2
p

# Violin plot of the KF.dist distance to the true tree per method, faceted by
# clock setting (columns) and phi scenario (rows). Same layout as foo(), but
# without a zero reference line.
p_dists <- df_summs %>%
    filter(variable == "dist") %>%
    ggplot(aes(x = method, y = value, fill = method)) +
        geom_violin(aes(color = method)) +
        # geom "pointrange" has no `width` parameter, so none is passed.
        stat_summary(fun.data = sum_func, geom = "pointrange", color = "gray35", fill = "gray35") +
        facet_grid(cols = vars(clock), rows = vars(phi_idx), scales = "free_y", labeller = label_bquote("Phi Scenario"==.(phi_idx), "Clock"==.(clock))) +
        theme_minimal() +
        labs(x = "Method", y = "Branch Score Distance") +
        ggtitle("dist") +
        scale_y_continuous(transform = scales::transform_modulus(0), breaks = mod_breaks(lambda = 0, prettify = TRUE)) +
        theme(
            axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
            plot.margin = margin(0, 0, 0, 0, "cm"),
            panel.grid.major = element_blank(),
            axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
            plot.title = element_text(hjust = 0.5, size = rel(1.0)))
# Write the figure to an 8 x 8 PDF, then print it inline.
pdf("../manuscript_figs/kmb_dists.pdf",8,8)
plot(p_dists)
dev.off()
#> quartz_off_screen 
#>                 2
p_dists

Print bias/variance

# Per method and statistic: mean error (bias), root mean squared error, the
# 5% and 95% quantiles, and the width of the range between them.
df_summs %>%
    group_by(method, variable) %>%
    summarise(
        bias = mean(value),
        rmse = sqrt(mean(value^2)),
        q5 = quantile(value, 0.05)[[1]],
        q95 = quantile(value, 0.95)[[1]],
        # Reuses the two quantile columns computed just above.
        iqr90 = q95 - q5
    ) %>%
    print(n = 100)
#> `summarise()` has grouped output by 'method'. You can override using the
#> `.groups` argument.
#> # A tibble: 12 × 7
#> # Groups:   method [3]
#>    method   variable                bias        rmse       q5      q95    iqr90
#>    <chr>    <chr>                  <dbl>       <dbl>    <dbl>    <dbl>    <dbl>
#>  1 lsd2     dist               20.4          25.5     6.62     54.3     47.7   
#>  2 lsd2     rel_branch_count   -0.217         0.232  -0.344    -0.0575   0.287 
#>  3 lsd2     rel_height          0.506         0.707  -0.00635   1.61     1.62  
#>  4 lsd2     rel_length          0.438         0.637  -0.0139    1.53     1.54  
#>  5 ours     dist               10.6          10.9     6.98     15.6      8.65  
#>  6 ours     rel_branch_count    0.0160        0.0298 -0.0151    0.0621   0.0772
#>  7 ours     rel_height          0.000653      0.125  -0.181     0.220    0.401 
#>  8 ours     rel_length         -0.00703       0.0776 -0.129     0.122    0.251 
#>  9 treetime dist             5375.       112331.      6.88    168.     161.    
#> 10 treetime rel_branch_count   -0.0922        0.126  -0.189     0.0911   0.280 
#> 11 treetime rel_height        142.         3185.     -0.361     3.76     4.12  
#> 12 treetime rel_length        192.         4479.     -0.209     2.66     2.87